Importing Modules¶

In [1]:
import pandas as pd # for data analysis
import numpy as np #for numerical operations
import matplotlib.pyplot as plt
import hvplot.pandas

Load Dataset¶

In [2]:
data=pd.read_csv(r"C:\Users\naga_\3. Coding\Heart Disease Code Implementation\HeartAttack.csv",na_values='?')
data.head()
Out[2]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
0 28 1 2 130.0 132.0 0.0 2.0 185.0 0.0 0.0 NaN NaN NaN 0
1 29 1 2 120.0 243.0 0.0 0.0 160.0 0.0 0.0 NaN NaN NaN 0
2 29 1 2 140.0 NaN 0.0 0.0 170.0 0.0 0.0 NaN NaN NaN 0
3 30 0 1 170.0 237.0 0.0 1.0 170.0 0.0 0.0 NaN NaN 6.0 0
4 31 0 2 100.0 219.0 0.0 1.0 150.0 0.0 0.0 NaN NaN NaN 0
In [3]:
data.head()
Out[3]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
0 28 1 2 130.0 132.0 0.0 2.0 185.0 0.0 0.0 NaN NaN NaN 0
1 29 1 2 120.0 243.0 0.0 0.0 160.0 0.0 0.0 NaN NaN NaN 0
2 29 1 2 140.0 NaN 0.0 0.0 170.0 0.0 0.0 NaN NaN NaN 0
3 30 0 1 170.0 237.0 0.0 1.0 170.0 0.0 0.0 NaN NaN 6.0 0
4 31 0 2 100.0 219.0 0.0 1.0 150.0 0.0 0.0 NaN NaN NaN 0
In [4]:
data.tail()
Out[4]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
289 52 1 4 160.0 331.0 0.0 0.0 94.0 1.0 2.5 NaN NaN NaN 1
290 54 0 3 130.0 294.0 0.0 1.0 100.0 1.0 0.0 2.0 NaN NaN 1
291 56 1 4 155.0 342.0 1.0 0.0 150.0 1.0 3.0 2.0 NaN NaN 1
292 58 0 2 180.0 393.0 0.0 0.0 110.0 1.0 1.0 2.0 NaN 7.0 1
293 65 1 4 130.0 275.0 0.0 1.0 115.0 1.0 1.0 2.0 NaN NaN 1
In [5]:
print("Number of records in dataset:",len(data))
Number of records in dataset: 294
In [6]:
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 294 entries, 0 to 293
Data columns (total 14 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   age         294 non-null    int64  
 1   sex         294 non-null    int64  
 2   cp          294 non-null    int64  
 3   trestbps    293 non-null    float64
 4   chol        271 non-null    float64
 5   fbs         286 non-null    float64
 6   restecg     293 non-null    float64
 7   thalach     293 non-null    float64
 8   exang       293 non-null    float64
 9   oldpeak     294 non-null    float64
 10  slope       104 non-null    float64
 11  ca          3 non-null      float64
 12  thal        28 non-null     float64
 13  num         294 non-null    int64  
dtypes: float64(10), int64(4)
memory usage: 32.3 KB
In [7]:
data.describe()
Out[7]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
count 294.000000 294.000000 294.000000 293.000000 271.000000 286.000000 293.000000 293.000000 293.000000 294.000000 104.000000 3.0 28.000000 294.000000
mean 47.826531 0.724490 2.982993 132.583618 250.848708 0.069930 0.218430 139.129693 0.303754 0.586054 1.894231 0.0 5.642857 0.360544
std 7.811812 0.447533 0.965117 17.626568 67.657711 0.255476 0.460868 23.589749 0.460665 0.908648 0.338995 0.0 1.615074 0.480977
min 28.000000 0.000000 1.000000 92.000000 85.000000 0.000000 0.000000 82.000000 0.000000 0.000000 1.000000 0.0 3.000000 0.000000
25% 42.000000 0.000000 2.000000 120.000000 209.000000 0.000000 0.000000 122.000000 0.000000 0.000000 2.000000 0.0 5.250000 0.000000
50% 49.000000 1.000000 3.000000 130.000000 243.000000 0.000000 0.000000 140.000000 0.000000 0.000000 2.000000 0.0 6.000000 0.000000
75% 54.000000 1.000000 4.000000 140.000000 282.500000 0.000000 0.000000 155.000000 1.000000 1.000000 2.000000 0.0 7.000000 1.000000
max 66.000000 1.000000 4.000000 200.000000 603.000000 1.000000 2.000000 190.000000 1.000000 5.000000 3.000000 0.0 7.000000 1.000000
In [8]:
#Image(filename = 'nlp_frames/lm/sm1.png')
dictionary = {
                "age     ": "Age of person",
                "sex     ": "1=male, 0=female",
                "cp      ": "Chest pain type, 1: typical angina, 2: atypical angina, 3: non-anginal pain, 4: asymptomatic",
                "trestbps": "Resting blood pressure",
                "chol    ": "Serum cholestoral in mg/dl",
                "fbs     ": "Fasting blood sugar > 120 mg/dl",
                "thalach ": "Maximum heart rate achieved",
                "restecg ": "Resting electrocardiographic results (values 0,1,2)",
                "exang   ": "Exercise induced angina",
                "oldpeak ": "Oldpeak = ST depression induced by exercise relative to rest",
                "slope   ": "The slope of the peak exercise ST segment",
                "ca      ": "Number of major vessels (0-3) colored by flourosopy",
                "thal    ": "Thalassemia  3 = normal; 6 = fixed defect; 7 = reversable defect"
}
for i in dictionary:
    print(i, ": ",dictionary[i])
age      :  Age of person
sex      :  1=male, 0=female
cp       :  Chest pain type, 1: typical angina, 2: atypical angina, 3: non-anginal pain, 4: asymptomatic
trestbps :  Resting blood pressure
chol     :  Serum cholestoral in mg/dl
fbs      :  Fasting blood sugar > 120 mg/dl
thalach  :  Maximum heart rate achieved
restecg  :  Resting electrocardiographic results (values 0,1,2)
exang    :  Exercise induced angina
oldpeak  :  Oldpeak = ST depression induced by exercise relative to rest
slope    :  The slope of the peak exercise ST segment
ca       :  Number of major vessels (0-3) colored by flourosopy
thal     :  Thalassemia  3 = normal; 6 = fixed defect; 7 = reversable defect

Remove NULL values¶

In [9]:
#couting the number of null values in each column
print("Count of each column's null values\n",data.isnull().sum())
Count of each column's null values
 age             0
sex             0
cp              0
trestbps        1
chol           23
fbs             8
restecg         1
thalach         1
exang           1
oldpeak         0
slope         190
ca            291
thal          266
num             0
dtype: int64
In [10]:
# slope,ca,thal these columns have more than 50% null values,so we are removing these columns
data = data.drop(columns = ["slope","ca","thal"])
print("Now the remainiing columns are:")
for i in data.columns:
    print(i,end="\t")
Now the remainiing columns are:
age	sex	cp	trestbps	chol	fbs	restecg	thalach	exang	oldpeak	num       	
In [11]:
#dropping the remaining null values
data=data.dropna()
print(data.isnull().sum())
age           0
sex           0
cp            0
trestbps      0
chol          0
fbs           0
restecg       0
thalach       0
exang         0
oldpeak       0
num           0
dtype: int64
In [12]:
#renaming the output column from num to target
data = data.rename(columns={"num       ":"target"})
data.head()
Out[12]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak target
0 28 1 2 130.0 132.0 0.0 2.0 185.0 0.0 0.0 0
1 29 1 2 120.0 243.0 0.0 0.0 160.0 0.0 0.0 0
3 30 0 1 170.0 237.0 0.0 1.0 170.0 0.0 0.0 0
4 31 0 2 100.0 219.0 0.0 1.0 150.0 0.0 0.0 0
5 32 0 2 105.0 198.0 0.0 0.0 165.0 0.0 0.0 0

Check for duplicates¶

In [13]:
print("Index  Existence of duplicate value\n")
print(data.duplicated())
print("Count of Duplicate rows : ",data.duplicated().sum())
print("Count of Non-Duplicate rows : ",(~data.duplicated()).sum())
data.loc[data.duplicated(),:]
Index  Existence of duplicate value

0      False
1      False
3      False
4      False
5      False
       ...  
289    False
290    False
291    False
292    False
293    False
Length: 261, dtype: bool
Count of Duplicate rows :  0
Count of Non-Duplicate rows :  261
Out[13]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak target

All the categorical columns & their values¶

In [14]:
#to see particular column data use the below method
data["cp"].value_counts()
Out[14]:
4    113
2     92
3     46
1     10
Name: cp, dtype: int64
In [15]:
data["exang"].value_counts()
Out[15]:
0.0    178
1.0     83
Name: exang, dtype: int64
In [16]:
data["fbs"].value_counts()
Out[16]:
0.0    242
1.0     19
Name: fbs, dtype: int64
In [17]:
data["restecg"].value_counts()
Out[17]:
0.0    208
1.0     47
2.0      6
Name: restecg, dtype: int64
In [18]:
data["sex"].value_counts()
Out[18]:
1    192
0     69
Name: sex, dtype: int64

Patient Records¶

In [19]:
had_disease = data[data['target'] == 0]
didnot_had_disease = data[data['target'] == 1]
print("Number of patients that had heart disease:",len(had_disease))
print("Number of patients that had heart disease:",len(didnot_had_disease))
Number of patients that had heart disease: 163
Number of patients that had heart disease: 98
In [20]:
print("Patient records that had heart disease: \n",had_disease)
print("Patient records that didn't had heart disease: \n",didnot_had_disease)
Patient records that had heart disease: 
      age  sex  cp  trestbps   chol  fbs  restecg  thalach  exang  oldpeak  \
0     28    1   2     130.0  132.0  0.0      2.0    185.0    0.0      0.0   
1     29    1   2     120.0  243.0  0.0      0.0    160.0    0.0      0.0   
3     30    0   1     170.0  237.0  0.0      1.0    170.0    0.0      0.0   
4     31    0   2     100.0  219.0  0.0      1.0    150.0    0.0      0.0   
5     32    0   2     105.0  198.0  0.0      0.0    165.0    0.0      0.0   
..   ...  ...  ..       ...    ...  ...      ...      ...    ...      ...   
183   60    1   3     120.0  246.0  0.0      2.0    135.0    0.0      0.0   
184   61    0   4     130.0  294.0  0.0      1.0    120.0    1.0      1.0   
185   61    1   4     125.0  292.0  0.0      1.0    115.0    1.0      0.0   
186   62    0   1     160.0  193.0  0.0      0.0    116.0    0.0      0.0   
187   62    1   2     140.0  271.0  0.0      0.0    152.0    0.0      1.0   

     target  
0         0  
1         0  
3         0  
4         0  
5         0  
..      ...  
183       0  
184       0  
185       0  
186       0  
187       0  

[163 rows x 11 columns]
Patient records that didn't had heart disease: 
      age  sex  cp  trestbps   chol  fbs  restecg  thalach  exang  oldpeak  \
188   31    1   4     120.0  270.0  0.0      0.0    153.0    1.0      1.5   
189   33    0   4     100.0  246.0  0.0      0.0    150.0    1.0      1.0   
190   34    1   1     140.0  156.0  0.0      0.0    180.0    0.0      0.0   
191   35    1   2     110.0  257.0  0.0      0.0    140.0    0.0      0.0   
192   36    1   2     120.0  267.0  0.0      0.0    160.0    0.0      3.0   
..   ...  ...  ..       ...    ...  ...      ...      ...    ...      ...   
289   52    1   4     160.0  331.0  0.0      0.0     94.0    1.0      2.5   
290   54    0   3     130.0  294.0  0.0      1.0    100.0    1.0      0.0   
291   56    1   4     155.0  342.0  1.0      0.0    150.0    1.0      3.0   
292   58    0   2     180.0  393.0  0.0      0.0    110.0    1.0      1.0   
293   65    1   4     130.0  275.0  0.0      1.0    115.0    1.0      1.0   

     target  
188       1  
189       1  
190       1  
191       1  
192       1  
..      ...  
289       1  
290       1  
291       1  
292       1  
293       1  

[98 rows x 11 columns]
In [21]:
data.columns
Out[21]:
Index(['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
       'exang', 'oldpeak', 'target'],
      dtype='object')

Bar graphs¶

In [22]:
data.target.value_counts().hvplot.bar(
    title="Heart Disease Count", xlabel='Heart Disease', ylabel='Count', 
    width=500, height=350
)
Out[22]:
In [23]:
data.fbs.value_counts().hvplot.bar(
    title="Heart Disease Count for fbs", xlabel='Heart Disease', ylabel='Count', color='yellow',
    width=500, height=350
)
Out[23]:
In [24]:
data.cp.value_counts().hvplot.bar(
    title="Heart Disease Count for CP", xlabel='Chest Pain Type', ylabel='Count', color='white',
    width=500, height=350
)
Out[24]:
In [25]:
data.trestbps.value_counts().hvplot.bar(
    title="Hemoglobin", xlabel='Hemoglobin', ylabel='Count', color='green',
    width=1200, height=350
)
Out[25]:
In [26]:
data.chol.value_counts().hvplot.bar(
    title="Chol", xlabel='Hemoglobin', ylabel='Count', color='lightblue',
    width=5000, height=250
)
Out[26]:
In [27]:
data.oldpeak.value_counts().hvplot.bar(
    title="Oldpeak", xlabel='Hemoglobin', ylabel='Count', color='green',
    width=500, height=450
)
Out[27]:

One-hot encoding¶

In [28]:
data=pd.get_dummies(data,columns = ["cp","restecg"])
numerical_cols=["age","trestbps","chol","thalach","oldpeak"]
cat_cols = list(set(data.columns)-set(numerical_cols)-{"target"})
print(numerical_cols)
print(cat_cols)
['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
['cp_1', 'fbs', 'sex', 'restecg_2.0', 'cp_3', 'restecg_1.0', 'cp_2', 'cp_4', 'exang', 'restecg_0.0']

Scaling Numerical Data using Standardization¶

In [29]:
# formula { y = (x – mean) / standard_deviation }
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
In [30]:
def my_fun(data,numerical_cols,cat_cols,scaler):
    x_scaled = scaler.fit_transform(data[numerical_cols])
    x_cat = data[cat_cols].to_numpy()
    x = np.hstack((x_cat,x_scaled))
    y = data["target"]
    return x,y
data_x,data_y = my_fun(data,numerical_cols,cat_cols,scaler)

splitting the data into train and test¶

In [31]:
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test =train_test_split(data_x,data_y,test_size = 0.2,random_state=54)
In [32]:
X_test
Out[32]:
array([[ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  7.95927016e-01, -7.15089404e-01,
        -1.65142113e-01,  6.23061871e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  2.84889006e-01,  4.16394891e-01,
        -2.71943755e-01,  3.26396520e-02,  4.71735305e+00],
       [ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -1.63150353e+00, -7.15089404e-01,
        -1.35521756e+00,  1.93042536e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  7.95927016e-01, -1.28083155e+00,
        -6.22863438e-01,  1.16985683e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         1.00000000e+00, -7.37187015e-01, -7.15089404e-01,
        -3.17715888e-01,  5.38715840e-01,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -4.81668010e-01,  9.82137038e-01,
         2.48964157e+00,  1.29783012e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  5.40408011e-01, -1.49347257e-01,
        -1.05007001e+00,  3.26396520e-02,  9.54427786e-01],
       [ 1.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -1.75926304e+00,  4.16394891e-01,
        -1.41624707e+00,  1.71956028e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -1.12046552e+00,  3.24510563e+00,
        -1.19369980e-01, -1.40124288e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  2.84889006e-01,  4.16394891e-01,
         1.40636777e+00, -5.99955583e-01,  2.02954929e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  7.95927016e-01,  3.81084777e+00,
        -7.75437213e-01,  1.16985683e-01,  1.49198854e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  4.12648508e-01, -4.32218330e-01,
        -9.28010988e-01,  2.43504730e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -1.12046552e+00, -7.15089404e-01,
         1.37585302e+00,  1.29783012e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -8.64946517e-01, -7.15089404e-01,
         7.04528405e-01,  1.29783012e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  9.23686518e-01, -1.28083155e+00,
         1.45213990e+00,  8.76099965e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  2.93700006e-02,  1.54787919e+00,
         1.61997106e+00, -1.69645399e+00,  1.49198854e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -1.24822502e+00, -7.15089404e-01,
         5.06182497e-01,  1.29783012e+00, -6.58254468e-01],
       [ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -2.27030105e+00,  2.11362133e+00,
        -1.80399490e-01,  1.29783012e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  6.68167513e-01, -7.15089404e-01,
        -8.21209346e-01,  3.26396520e-02, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -3.53908507e-01,  4.16394891e-01,
        -3.78745398e-01,  2.01331715e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -6.09427512e-01, -7.15089404e-01,
        -7.29665081e-01,  1.08696504e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00, -2.26149005e-01, -7.15089404e-01,
        -2.71943755e-01, -1.02168574e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  6.68167513e-01,  4.16394891e-01,
         1.08596284e+00,  9.60445996e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  5.40408011e-01, -1.49347257e-01,
        -3.63488020e-01, -8.10820661e-01,  1.49198854e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  2.84889006e-01,  4.16394891e-01,
         5.97726763e-01,  3.26396520e-02, -6.58254468e-01],
       [ 0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  2.20128154e+00,  4.16394891e-01,
         8.72359558e-01, -2.20253018e+00,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00, -1.24822502e+00, -2.29916742e+00,
        -2.01128479e+00, -2.20398442e-01,  2.02954929e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00, -1.37598453e+00,  4.16394891e-01,
        -6.38120816e-01, -3.89090505e-01,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  2.93700006e-02, -7.15089404e-01,
         5.36697252e-01, -8.10820661e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  9.23686518e-01,  4.16394891e-01,
        -8.05951968e-01,  4.54369808e-01, -6.58254468e-01],
       [ 0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.05144602e+00,  9.82137038e-01,
        -5.46576551e-01, -5.99955583e-01,  4.16867035e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.05144602e+00, -1.49347257e-01,
        -9.89040498e-01, -1.65428097e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -2.26149005e-01, -8.28237833e-01,
        -9.58525743e-01, -6.42128598e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.57129503e-01, -1.49347257e-01,
        -6.38120816e-01, -1.78225426e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -2.01478204e+00, -1.56370262e+00,
        -7.75437213e-01,  1.08696504e+00, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00, -2.14254154e+00, -7.15089404e-01,
         3.23093967e-01,  5.80888855e-01,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  6.68167513e-01, -1.49347257e-01,
        -1.01955525e+00,  3.70023777e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -6.09427512e-01,  9.82137038e-01,
        -9.58525743e-01,  6.23061871e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  5.40408011e-01, -1.49347257e-01,
         7.50300538e-01, -1.23255082e+00,  4.16867035e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -1.63150353e+00,  4.16394891e-01,
        -1.24841592e+00,  4.54369808e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  4.12648508e-01, -1.28083155e+00,
        -8.97496233e-01, -8.10820661e-01, -6.58254468e-01],
       [ 1.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -6.09427512e-01, -7.15089404e-01,
         6.43498895e-01,  6.65234887e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  4.12648508e-01,  1.33523817e-01,
        -1.35521756e+00,  4.54369808e-01,  1.49198854e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.57129503e-01, -1.84657370e+00,
         6.37185497e-02,  1.46652218e+00, -6.58254468e-01],
       [ 1.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00, -9.83895020e-02, -1.28083155e+00,
         2.68903971e-03,  4.54369808e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  5.40408011e-01,  1.54787919e+00,
        -4.30830928e-02, -2.41339526e+00,  3.64223154e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  2.93700006e-02,  3.03246461e-01,
        -5.31319173e-01, -1.31689685e+00,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  7.95927016e-01, -7.15089404e-01,
        -4.85547041e-01, -9.38793950e-02, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  6.68167513e-01,  2.67936348e+00,
         5.51954630e-01, -8.10820661e-01,  9.54427786e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.43472453e+00,  4.16394891e-01,
         5.82469385e-01,  4.54369808e-01, -6.58254468e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  5.40408011e-01, -1.16768312e+00,
         1.42162515e+00, -1.82297304e+00,  4.16867035e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.57129503e-01,  1.54787919e+00,
        -1.05007001e+00,  7.07407902e-01,  4.16867035e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  2.93700006e-02,  1.54787919e+00,
         1.22327924e+00, -1.99166510e+00,  9.54427786e-01]])
In [33]:
len(X_train),len(y_train),len(X_test),len(y_test)
Out[33]:
(208, 208, 53, 53)

Support Vector Machine¶

Training the model¶

In [34]:
from sklearn.svm import SVC
from sklearn.metrics import mean_squared_error,accuracy_score
svm_clf = SVC(kernel='rbf',C=2)
svm_clf.fit(X_train,y_train)
svm_clas = SVC(kernel='rbf',C=2)
svm_clas.fit(X_train,y_train)
Out[34]:
SVC(C=2)

Testing the model¶

In [35]:
y_predict = svm_clf.predict(X_test)
y_predic = svm_clas.predict(X_test)
In [36]:
t = mean_squared_error(y_test,y_predic)
print('Mean squred error is : ',t)
svm_accuracy = accuracy_score(y_test,y_predic)
print("Accuracy of the SVM model is : {0:.2f}".format(svm_accuracy*100),'%')
Mean squred error is :  0.1509433962264151
Accuracy of the SVM model is : 84.91 %
In [37]:
a = mean_squared_error(y_test,y_predict)
print('Mean squred error is : ',a)
svm_accuracy_1 = accuracy_score(y_test,y_predict)
print("Accuracy of the SVM model is : {0:.2f}".format(svm_accuracy_1*100),'%')
Mean squred error is :  0.1509433962264151
Accuracy of the SVM model is : 84.91 %

Miscalssified samples of SVM¶

In [38]:
misclassified_samples = X_test[y_test != y_predict]
print("Misclassified data : \n\n",misclassified_samples)
Misclassified data : 

 [[ 0.          0.          1.          0.          1.          0.
   0.          0.          1.          1.         -0.73718701 -0.7150894
  -0.31771589  0.53871584  0.95442779]
 [ 1.          0.          1.          0.          0.          0.
   0.          0.          0.          1.         -1.75926304  0.41639489
  -1.41624707  1.71956028 -0.65825447]
 [ 0.          0.          1.          0.          0.          0.
   0.          1.          0.          1.         -1.24822502 -0.7150894
   0.5061825   1.29783012 -0.65825447]
 [ 0.          0.          0.          0.          1.          0.
   0.          0.          1.          1.          0.28488901  0.41639489
   0.59772676  0.03263965 -0.65825447]
 [ 0.          0.          1.          0.          0.          0.
   0.          1.          0.          1.         -0.226149   -0.82823783
  -0.95852574 -0.6421286  -0.65825447]
 [ 1.          0.          1.          0.          0.          1.
   0.          0.          0.          0.         -0.60942751 -0.7150894
   0.6434989   0.66523489 -0.65825447]
 [ 0.          0.          1.          0.          1.          0.
   0.          0.          0.          1.          0.41264851  0.13352382
  -1.35521756  0.45436981  1.49198854]
 [ 0.          0.          0.          0.          1.          0.
   0.          0.          0.          1.          0.1571295   1.54787919
  -1.05007001  0.7074079   0.41686703]]
In [39]:
type(misclassified_samples)
Out[39]:
numpy.ndarray
In [40]:
indices_of_misclassified_samples =np.flatnonzero(y_test != y_predict)
print("Indices of misclassified samples",indices_of_misclassified_samples)

pred_misclassified = svm_clf.predict(misclassified_samples)
Indices of misclassified samples [ 4  7 16 24 32 41 42 51]
In [41]:
p=[]
for i in indices_of_misclassified_samples:
    p.append(y_predict[i])
print(p)
q = np.array(p)

print("Actual values : ",q)
print("Predicted values : ",pred_misclassified)
[1, 0, 0, 0, 0, 0, 0, 0]
Actual values :  [1 0 0 0 0 0 0 0]
Predicted values :  [1 0 0 0 0 0 0 0]
In [42]:
accuracy = accuracy_score(q,pred_misclassified)
print("Accuracy of the SVM model is : {0:.2f}".format(accuracy*100),'%')
Accuracy of the SVM model is : 100.00 %

Plotiing the SVM classifier¶

In [43]:
import matplotlib.colors as colors
from sklearn.decomposition import PCA
pca = PCA()
In [44]:
X_train_pca = pca.fit_transform(X_train)
In [45]:
X_train_pca[0],X_train_pca[1]
Out[45]:
(array([ 1.15452653e-01,  1.39028668e-01, -8.47378423e-02, -3.92291575e-01,
        -1.49485733e+00, -7.61597304e-01,  4.87393314e-01,  4.34839849e-02,
         3.95942339e-01, -2.20953508e-01, -9.82187401e-02, -1.35618675e-02,
        -1.29583302e-01, -7.08363565e-16,  1.38164295e-15]),
 array([ 1.42810252e+00, -7.93852886e-01,  9.89126277e-02, -4.67171192e-01,
         4.89833764e-01,  4.49438408e-02, -5.13339970e-01, -3.56621451e-01,
         5.40032777e-02,  4.69187987e-01, -1.47321097e-01, -2.76163541e-03,
         5.41683179e-02, -8.84074313e-16, -6.16666601e-17]))
In [46]:
# reduce the value and make it as an array
a.ravel()
Out[46]:
array([0.1509434])
In [47]:
pc1 = X_train_pca[:, 0] 
pc2 = X_train_pca[:, 1]
svm_clf.fit(np.column_stack((pc1, pc2)), y_train)
x_min = pc1.min() - 1
x_max = pc1.max() + 1
y_min = pc2.min() - 1
y_max = pc2.max() + 1
a, b = np.meshgrid(np.arange(start=x_min, stop=x_max, step=0.1),np.arange(start=y_min, stop=y_max, step=0.1))
Z = svm_clf.predict(np.column_stack((a.ravel(), b.ravel()))) ## Array of zeros and ones
Z = Z.reshape(a.shape)
fig, ax = plt.subplots(figsize=(10,10))
ax.contourf(a, b, Z, alpha=0.5)

cmap = colors.ListedColormap(['#4daf4a','#e41a1c'])

scatter = ax.scatter(pc1, pc2, c=y_train, 
               s=100,
               cmap=cmap,
               edgecolors='black',
               alpha=0.7)

legend = ax.legend(scatter.legend_elements()[0], 
                   scatter.legend_elements()[1],
                    loc="upper right")
legend.get_texts()[0].set_text("Don't have heart disease")
legend.get_texts()[1].set_text("Have heart disease")
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Decison surface using the PCA transformed/projected features')
plt.show()
In [48]:
len(pc1),len(pc2)
Out[48]:
(208, 208)
In [49]:
pc1
Out[49]:
array([ 0.11545265,  1.42810252, -1.41543192, -1.42857101, -0.9555822 ,
       -0.68374196,  0.80471621, -0.96908418,  1.90867931,  0.01946177,
        1.64051338, -3.06881372, -1.27018926, -0.07957789, -1.32206018,
        0.38105107,  1.31807723, -0.28063268,  1.89508304,  1.79933062,
        0.81628878,  0.11710698, -0.43304948,  1.63023815, -0.25626239,
        0.02214917,  0.18313682,  0.25564787,  0.90974829,  1.23626023,
       -1.08471065,  1.40623497,  1.12871814, -0.5683718 , -0.14676822,
       -1.45794626, -3.31724669,  0.15043652, -2.43705895,  0.86123486,
        0.81177502,  1.67013749,  1.40543808, -0.72299461,  1.34123971,
       -1.76321414, -1.95758175,  0.25637693,  1.96923591, -0.65280736,
       -0.2161977 ,  0.1098404 , -1.66777205, -2.45454847,  0.13587878,
        0.10146495, -2.09754426, -0.55768274, -1.91709225,  0.34709503,
        2.62248017,  2.58698231,  1.6363709 ,  3.17386317,  0.19853135,
        0.60588114,  2.22550612,  0.03269122, -0.66894801, -2.74945295,
        1.88213748, -0.40362206,  1.75903698, -0.82729379,  1.2888823 ,
       -1.10868775, -0.65752101, -0.99611678,  0.35355766, -0.76482734,
        1.51379669, -1.01979264, -1.40444099,  1.46858175,  0.2654941 ,
        1.80707416,  1.10089791,  0.90442876,  2.28315915, -1.26479395,
       -1.90296486,  0.66241818, -0.14075059, -0.12450113,  2.01826257,
       -1.03270229,  0.73155075, -0.70733297, -0.37183462, -1.33892344,
       -1.00238314, -1.56436624, -1.36693183,  1.07351424,  3.38421053,
       -1.18562794, -1.12346826, -0.07519186, -1.45206569, -0.33287222,
        0.51403768,  1.99290293,  1.49110056, -0.16911046,  2.50389792,
       -1.07370496, -0.53486367, -0.98667085, -0.41683858,  0.89016324,
       -1.22408672, -0.82382924,  3.28319212, -1.42148868,  2.23694103,
        1.46548502,  3.48379544, -2.58124713,  0.90626053, -1.4762942 ,
        0.55864342,  1.24973351, -1.14574842, -0.74058077,  1.56805608,
        1.27749498,  0.33810178,  1.86196089,  1.05068698,  0.65595003,
       -2.23699307,  0.44911285,  1.64454446, -2.57866579,  1.29885782,
       -1.11040031, -1.21053605,  0.60784353, -2.27368971,  1.15968397,
       -0.61546662,  2.5170607 ,  0.65237975,  1.98534483,  0.96450615,
        0.16193757, -0.4255012 , -0.15697487, -0.71932201,  0.16053541,
       -0.19958774,  0.32087466,  1.78276706, -1.08173119, -1.39985006,
        1.63750325, -0.20025918, -1.56482795,  0.26319422, -2.95284383,
       -1.34213744,  1.27656695, -0.56505852, -1.61904457, -0.30853902,
       -0.55876187, -0.00530676, -0.31285635,  0.73151205,  1.22862473,
       -1.84515229, -0.75046766, -1.50693857, -0.93446328, -2.46373231,
       -0.02415711, -0.09856972,  2.98144398, -1.11731278,  0.02011362,
        0.59236785, -1.77716525, -0.85260784,  0.34947614,  2.56862825,
       -1.62047594, -2.24255513, -0.7928943 ,  1.45195013,  1.59850513,
        0.98129132,  0.56460493, -2.01043139, -1.78081029, -2.25361628,
       -0.09727494, -1.62322063, -0.44048328])
In [50]:
pc2
Out[50]:
array([ 1.39028668e-01, -7.93852886e-01,  2.78029806e-01,  9.31359900e-01,
       -5.79570053e-01,  1.89514997e+00, -7.96346692e-01,  7.61959442e-01,
       -3.56320160e-02,  6.81679668e-02,  9.89506407e-01,  6.99442105e-01,
       -1.12269365e+00, -7.40784753e-01, -1.47331018e-02, -8.75919152e-01,
       -5.35675264e-01, -1.85556517e+00, -1.86340222e-02, -4.58529637e-01,
       -1.18611392e+00, -1.13544254e+00, -3.02488370e-01,  6.43253543e-03,
       -4.38544731e-01, -7.22736429e-01,  2.20584812e-01, -7.45806269e-01,
       -2.06516248e-01, -1.05939565e-01,  6.61821633e-02, -4.05185547e-01,
        3.13460807e-01, -6.49878383e-02,  5.15051661e-01,  8.52589521e-01,
       -4.88664099e-01, -5.74739471e-01,  3.58137643e-01, -3.24311246e-01,
       -8.51354830e-01, -8.59785619e-01, -1.89119208e+00,  4.72279163e-01,
       -1.64741863e-01,  6.82101304e-01, -1.10794836e-01, -2.45036670e-01,
       -1.79963211e+00, -9.94231300e-01,  4.52865461e+00, -5.00798355e-01,
       -3.96959297e-01,  1.59349403e+00, -5.09911748e-01, -2.41455172e-02,
       -4.78187778e-05, -9.61006517e-01, -6.81350026e-01, -6.59157339e-01,
        6.27574911e-01, -4.51229487e-01,  7.44860394e-01,  1.44342211e+00,
        6.10131780e-01, -1.60851496e-01,  1.46355721e-01,  8.27101933e-02,
       -1.01370879e+00, -3.74502649e-01, -6.64100151e-01, -6.32277709e-01,
        3.14631319e-01, -7.69358523e-01,  9.49020677e-01,  5.14532469e-01,
       -5.51846562e-01, -5.33022553e-02, -2.20828172e-01,  1.66429602e-02,
        2.32222574e-01, -4.55450761e-01, -1.58529466e-01,  1.83380041e+00,
        1.55592690e+00, -1.64773854e+00, -6.46145870e-02,  1.91045447e-01,
        2.04496288e+00,  1.85584719e-01,  3.77146930e-01, -4.66100642e-01,
       -3.27757037e-01,  1.68057924e-01, -1.70623816e+00,  1.00133523e+00,
       -4.21103848e-01,  1.81505843e-01,  2.65675565e-01, -6.58625295e-01,
       -5.24170072e-01, -4.59906926e-01, -1.14571096e+00,  1.00259435e+00,
       -3.64618962e-01,  6.59806502e-01, -7.53728077e-01, -2.97183044e-01,
        1.01622877e+00,  7.60679255e-01,  5.64888634e-01, -6.84725415e-01,
       -1.94045937e-01, -7.97903095e-01,  1.27195079e+00, -2.55195280e-01,
        2.02993023e-01, -3.67823720e-01, -2.39047381e+00, -5.63878635e-01,
        4.87739563e-01, -4.53675828e-01,  7.99189009e-01,  9.55239183e-01,
       -5.00283721e-01,  3.13655987e+00, -7.17471162e-01, -2.96055935e-01,
       -3.51177991e-01,  3.15255901e-02,  2.74564266e-01, -4.80822678e-01,
       -8.69295511e-01, -1.69972269e-01,  5.44839045e-03,  2.37662352e+00,
       -4.50526110e-01, -5.82668658e-01,  2.68144670e-01,  3.39397261e+00,
       -2.85587475e-01, -1.16004796e+00,  3.11197143e-01,  4.32893956e-01,
       -1.18395481e+00,  4.15029144e-01, -3.73489833e-01, -3.89991369e-01,
       -4.07087426e-01,  6.35885165e-01,  1.17811008e+00, -7.70632214e-01,
        7.42037161e-01,  8.14596202e-01, -1.27280073e+00,  4.83662507e-02,
        7.53920189e-01, -8.73631496e-01,  1.50343973e-01,  4.48568616e-02,
       -2.36508523e-02, -6.24667498e-01,  8.18963535e-02, -1.93455445e-01,
        3.48297489e-01,  1.23500531e+00, -5.08090127e-01,  7.67151402e-01,
       -1.70114397e-01, -3.48099848e-01,  9.85840434e-01,  1.20990892e+00,
       -2.59668565e+00,  1.15757740e+00,  6.31236771e-01,  1.28040084e+00,
       -7.07427401e-01, -1.48573401e+00, -1.48176752e+00,  5.24616501e-02,
        3.60514090e-01, -5.57691042e-01, -9.08968294e-02, -5.83182639e-01,
        8.62128409e-01, -8.10917226e-02,  4.34062963e-01,  1.04056475e+00,
        1.50934936e+00, -6.44645441e-01, -1.15676960e+00,  2.20075797e+00,
       -6.34941035e-01, -2.24228478e-01,  4.60297427e+00, -2.45772048e-01,
        1.60306737e+00, -1.37782077e+00, -5.02189582e-01, -1.37643229e+00,
        9.61379772e-01, -1.41677112e+00,  8.09838037e-01,  1.07539773e-01,
        1.35445599e-01,  7.74260213e-01, -3.15793063e-01, -1.03162450e+00])

Random forest¶

In [51]:
from sklearn.ensemble import RandomForestClassifier

rf_max_accuracy = 0
for x in range(200):
    for j in range(1,20):
        rf = RandomForestClassifier(random_state=x,n_estimators= j,criterion="entropy")
        rf.fit(X_train,y_train)
        Y_pred_rf = rf.predict(X_test)
        current_accuracy = round(accuracy_score(Y_pred_rf,y_test)*100,2)
        if(current_accuracy>rf_max_accuracy):
            rf_max_accuracy = current_accuracy
            best_x = x
            best_y = j
rf_clf = RandomForestClassifier(random_state=best_x,n_estimators = best_y)
rf_clf.fit(X_train,y_train)
Y_pred_rf = rf.predict(X_test)
In [52]:
print("Accuracy of Random Forest: ",rf_max_accuracy,"%")
print("Used Random state ",best_x)
print("Number of Decision Trees used :",best_y)
Accuracy of Random Forest:  90.57 %
Used Random state  96
Number of Decision Trees used : 19

Misclassified data of Random forest¶

In [53]:
misclassified_samples_of_RF = X_test[y_test != Y_pred_rf]
print("Misclassified data : \n\n",misclassified_samples_of_RF)
Misclassified data : 

 [[ 0.          0.          1.          0.          0.          0.
   0.          1.          0.          1.         -0.48166801  0.98213704
   2.48964157  1.29783012 -0.65825447]
 [ 0.          0.          0.          0.          0.          0.
   0.          1.          1.          1.          0.54040801 -0.14934726
  -1.05007001  0.03263965  0.95442779]
 [ 1.          0.          1.          0.          0.          0.
   0.          0.          0.          1.         -1.75926304  0.41639489
  -1.41624707  1.71956028 -0.65825447]
 [ 0.          0.          0.          0.          1.          0.
   0.          0.          1.          1.          0.28488901  0.41639489
   0.59772676  0.03263965 -0.65825447]
 [ 0.          0.          1.          0.          0.          0.
   0.          1.          0.          1.          0.66816751 -0.14934726
  -1.01955525  0.37002378 -0.65825447]
 [ 1.          0.          1.          0.          0.          1.
   0.          0.          0.          0.         -0.60942751 -0.7150894
   0.6434989   0.66523489 -0.65825447]
 [ 0.          0.          1.          0.          1.          0.
   0.          0.          0.          1.          0.41264851  0.13352382
  -1.35521756  0.45436981  1.49198854]
 [ 0.          0.          1.          0.          0.          1.
   0.          1.          1.          0.          0.54040801 -1.16768312
   1.42162515 -1.82297304  0.41686703]
 [ 0.          0.          0.          0.          1.          0.
   0.          0.          0.          1.          0.1571295   1.54787919
  -1.05007001  0.7074079   0.41686703]]
In [54]:
indices_of_misclassified_samples_of_RF =np.flatnonzero(y_test != Y_pred_rf)
print(indices_of_misclassified_samples_of_RF)

predicted_misclassified_data_of_RF = rf_clf.predict(misclassified_samples_of_RF)
[ 5  6  7 24 36 41 42 50 51]
In [55]:
w=[]
for i in indices_of_misclassified_samples_of_RF:
    w.append(Y_pred_rf[i])
print(w)
m = np.array(w)

print("Actual values : ",m)
print("Predicted values : ",predicted_misclassified_data_of_RF)

rf_accuracy = accuracy_score(m,predicted_misclassified_data_of_RF)
print("Accuracy of the Random forest model is : {0:.2f}".format(rf_accuracy*100),'%')
[1, 1, 0, 0, 1, 0, 0, 0, 0]
Actual values :  [1 1 0 0 1 0 0 0 0]
Predicted values :  [0 0 0 1 1 0 0 1 0]
Accuracy of the Random forest model is : 55.56 %

TabNet model¶

installing the model¶

In [56]:
#!pip install pytorch_tabnet

Improting the model¶

In [57]:
from pytorch_tabnet.tab_model import TabNetClassifier
import torch

# define the model
tn_clf= TabNetClassifier(optimizer_fn=torch.optim.Adam,
                      scheduler_params={"step_size":20,"gamma":1.2},
                      scheduler_fn=torch.optim.lr_scheduler.StepLR)

# fit the model 
tn_clf.fit(
    X_train,y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    eval_name=['train', 'test'],
    eval_metric=['auc','balanced_accuracy'],
    max_epochs=53,
    patience=40,
    batch_size=203, 
    virtual_batch_size=203,
    num_workers=0,
    weights=1,
    drop_last=False
)
C:\Users\naga_\AppData\Roaming\Python\Python39\site-packages\pytorch_tabnet\abstract_model.py:75: UserWarning: Device used : cpu
  warnings.warn(f"Device used : {self.device}")
epoch 0  | loss: 0.69701 | train_auc: 0.48154 | train_balanced_accuracy: 0.48507 | test_auc: 0.38937 | test_balanced_accuracy: 0.34626 |  0:00:00s
epoch 1  | loss: 0.77254 | train_auc: 0.54921 | train_balanced_accuracy: 0.51603 | test_auc: 0.45115 | test_balanced_accuracy: 0.41236 |  0:00:00s
epoch 2  | loss: 0.61589 | train_auc: 0.55698 | train_balanced_accuracy: 0.5475  | test_auc: 0.47845 | test_balanced_accuracy: 0.48491 |  0:00:00s
epoch 3  | loss: 0.59854 | train_auc: 0.53358 | train_balanced_accuracy: 0.50696 | test_auc: 0.52011 | test_balanced_accuracy: 0.57471 |  0:00:00s
epoch 4  | loss: 0.61337 | train_auc: 0.58239 | train_balanced_accuracy: 0.58179 | test_auc: 0.54598 | test_balanced_accuracy: 0.55029 |  0:00:00s
epoch 5  | loss: 0.5371  | train_auc: 0.61527 | train_balanced_accuracy: 0.62001 | test_auc: 0.65086 | test_balanced_accuracy: 0.58836 |  0:00:00s
epoch 6  | loss: 0.4734  | train_auc: 0.6786  | train_balanced_accuracy: 0.65077 | test_auc: 0.71121 | test_balanced_accuracy: 0.68175 |  0:00:00s
epoch 7  | loss: 0.47108 | train_auc: 0.7615  | train_balanced_accuracy: 0.70442 | test_auc: 0.74138 | test_balanced_accuracy: 0.71264 |  0:00:00s
epoch 8  | loss: 0.4279  | train_auc: 0.80234 | train_balanced_accuracy: 0.74546 | test_auc: 0.71695 | test_balanced_accuracy: 0.70905 |  0:00:00s
epoch 9  | loss: 0.46938 | train_auc: 0.83905 | train_balanced_accuracy: 0.77904 | test_auc: 0.74569 | test_balanced_accuracy: 0.70905 |  0:00:00s
epoch 10 | loss: 0.41363 | train_auc: 0.8577  | train_balanced_accuracy: 0.77441 | test_auc: 0.78592 | test_balanced_accuracy: 0.75072 |  0:00:00s
epoch 11 | loss: 0.43136 | train_auc: 0.85629 | train_balanced_accuracy: 0.76765 | test_auc: 0.84339 | test_balanced_accuracy: 0.74713 |  0:00:00s
epoch 12 | loss: 0.57975 | train_auc: 0.86971 | train_balanced_accuracy: 0.75413 | test_auc: 0.85057 | test_balanced_accuracy: 0.72629 |  0:00:00s
epoch 13 | loss: 0.42757 | train_auc: 0.85962 | train_balanced_accuracy: 0.7616  | test_auc: 0.86925 | test_balanced_accuracy: 0.78161 |  0:00:00s
epoch 14 | loss: 0.40582 | train_auc: 0.87298 | train_balanced_accuracy: 0.75928 | test_auc: 0.88075 | test_balanced_accuracy: 0.77802 |  0:00:00s
epoch 15 | loss: 0.32993 | train_auc: 0.85922 | train_balanced_accuracy: 0.75252 | test_auc: 0.8592  | test_balanced_accuracy: 0.77802 |  0:00:00s
epoch 16 | loss: 0.35267 | train_auc: 0.85705 | train_balanced_accuracy: 0.76603 | test_auc: 0.86638 | test_balanced_accuracy: 0.77802 |  0:00:00s
epoch 17 | loss: 0.35237 | train_auc: 0.85891 | train_balanced_accuracy: 0.75555 | test_auc: 0.87213 | test_balanced_accuracy: 0.76078 |  0:00:00s
epoch 18 | loss: 0.28544 | train_auc: 0.86633 | train_balanced_accuracy: 0.78025 | test_auc: 0.86925 | test_balanced_accuracy: 0.79885 |  0:00:00s
epoch 19 | loss: 0.42675 | train_auc: 0.86527 | train_balanced_accuracy: 0.7495  | test_auc: 0.875   | test_balanced_accuracy: 0.79885 |  0:00:01s
epoch 20 | loss: 0.286   | train_auc: 0.87354 | train_balanced_accuracy: 0.7495  | test_auc: 0.86494 | test_balanced_accuracy: 0.77802 |  0:00:01s
epoch 21 | loss: 0.38582 | train_auc: 0.86587 | train_balanced_accuracy: 0.73528 | test_auc: 0.85057 | test_balanced_accuracy: 0.74353 |  0:00:01s
epoch 22 | loss: 0.4734  | train_auc: 0.85569 | train_balanced_accuracy: 0.74879 | test_auc: 0.8592  | test_balanced_accuracy: 0.76437 |  0:00:01s
epoch 23 | loss: 0.54582 | train_auc: 0.85115 | train_balanced_accuracy: 0.76694 | test_auc: 0.87069 | test_balanced_accuracy: 0.76796 |  0:00:01s
epoch 24 | loss: 0.44122 | train_auc: 0.84913 | train_balanced_accuracy: 0.75504 | test_auc: 0.87069 | test_balanced_accuracy: 0.73707 |  0:00:01s
epoch 25 | loss: 0.35753 | train_auc: 0.85402 | train_balanced_accuracy: 0.77531 | test_auc: 0.86494 | test_balanced_accuracy: 0.73348 |  0:00:01s
epoch 26 | loss: 0.30352 | train_auc: 0.85987 | train_balanced_accuracy: 0.79185 | test_auc: 0.87213 | test_balanced_accuracy: 0.80963 |  0:00:01s
epoch 27 | loss: 0.43136 | train_auc: 0.87011 | train_balanced_accuracy: 0.81515 | test_auc: 0.87069 | test_balanced_accuracy: 0.81322 |  0:00:01s
epoch 28 | loss: 0.42508 | train_auc: 0.8693  | train_balanced_accuracy: 0.79256 | test_auc: 0.89511 | test_balanced_accuracy: 0.81322 |  0:00:01s
epoch 29 | loss: 0.42671 | train_auc: 0.86456 | train_balanced_accuracy: 0.78046 | test_auc: 0.93103 | test_balanced_accuracy: 0.86135 |  0:00:01s
epoch 30 | loss: 0.42027 | train_auc: 0.86426 | train_balanced_accuracy: 0.76977 | test_auc: 0.93678 | test_balanced_accuracy: 0.84052 |  0:00:01s
epoch 31 | loss: 0.40137 | train_auc: 0.87132 | train_balanced_accuracy: 0.7735  | test_auc: 0.94253 | test_balanced_accuracy: 0.84052 |  0:00:01s
epoch 32 | loss: 0.39828 | train_auc: 0.87631 | train_balanced_accuracy: 0.76745 | test_auc: 0.9181  | test_balanced_accuracy: 0.84052 |  0:00:01s
epoch 33 | loss: 0.46473 | train_auc: 0.88009 | train_balanced_accuracy: 0.78399 | test_auc: 0.91667 | test_balanced_accuracy: 0.80244 |  0:00:01s
epoch 34 | loss: 0.44311 | train_auc: 0.88796 | train_balanced_accuracy: 0.77582 | test_auc: 0.92241 | test_balanced_accuracy: 0.84411 |  0:00:01s
epoch 35 | loss: 0.32967 | train_auc: 0.89129 | train_balanced_accuracy: 0.77441 | test_auc: 0.91236 | test_balanced_accuracy: 0.82328 |  0:00:01s
epoch 36 | loss: 0.46221 | train_auc: 0.905   | train_balanced_accuracy: 0.77067 | test_auc: 0.89655 | test_balanced_accuracy: 0.84411 |  0:00:01s
epoch 37 | loss: 0.36585 | train_auc: 0.89794 | train_balanced_accuracy: 0.76694 | test_auc: 0.89368 | test_balanced_accuracy: 0.84411 |  0:00:01s
epoch 38 | loss: 0.37152 | train_auc: 0.89845 | train_balanced_accuracy: 0.77814 | test_auc: 0.89511 | test_balanced_accuracy: 0.86135 |  0:00:01s
epoch 39 | loss: 0.41971 | train_auc: 0.90702 | train_balanced_accuracy: 0.77441 | test_auc: 0.92529 | test_balanced_accuracy: 0.86135 |  0:00:02s
epoch 40 | loss: 0.32774 | train_auc: 0.89794 | train_balanced_accuracy: 0.76089 | test_auc: 0.94397 | test_balanced_accuracy: 0.84052 |  0:00:02s
epoch 41 | loss: 0.34285 | train_auc: 0.89744 | train_balanced_accuracy: 0.7856  | test_auc: 0.94253 | test_balanced_accuracy: 0.79885 |  0:00:02s
epoch 42 | loss: 0.42835 | train_auc: 0.88972 | train_balanced_accuracy: 0.78792 | test_auc: 0.94253 | test_balanced_accuracy: 0.80244 |  0:00:02s
epoch 43 | loss: 0.34187 | train_auc: 0.88604 | train_balanced_accuracy: 0.79256 | test_auc: 0.93247 | test_balanced_accuracy: 0.82687 |  0:00:02s
epoch 44 | loss: 0.37511 | train_auc: 0.88418 | train_balanced_accuracy: 0.80163 | test_auc: 0.93103 | test_balanced_accuracy: 0.85129 |  0:00:02s
epoch 45 | loss: 0.35731 | train_auc: 0.87192 | train_balanced_accuracy: 0.80163 | test_auc: 0.93247 | test_balanced_accuracy: 0.8477  |  0:00:02s
epoch 46 | loss: 0.37645 | train_auc: 0.85801 | train_balanced_accuracy: 0.79256 | test_auc: 0.92385 | test_balanced_accuracy: 0.86135 |  0:00:02s
epoch 47 | loss: 0.3999  | train_auc: 0.85902 | train_balanced_accuracy: 0.79256 | test_auc: 0.92672 | test_balanced_accuracy: 0.86135 |  0:00:02s
epoch 48 | loss: 0.42497 | train_auc: 0.86295 | train_balanced_accuracy: 0.79931 | test_auc: 0.91667 | test_balanced_accuracy: 0.84411 |  0:00:02s
epoch 49 | loss: 0.42232 | train_auc: 0.87394 | train_balanced_accuracy: 0.80022 | test_auc: 0.91954 | test_balanced_accuracy: 0.80603 |  0:00:02s
epoch 50 | loss: 0.41104 | train_auc: 0.88433 | train_balanced_accuracy: 0.8093  | test_auc: 0.90517 | test_balanced_accuracy: 0.78879 |  0:00:02s
epoch 51 | loss: 0.42426 | train_auc: 0.8923  | train_balanced_accuracy: 0.80557 | test_auc: 0.90086 | test_balanced_accuracy: 0.79598 |  0:00:02s
epoch 52 | loss: 0.40706 | train_auc: 0.89729 | train_balanced_accuracy: 0.80204 | test_auc: 0.91379 | test_balanced_accuracy: 0.83764 |  0:00:02s
Stop training because you reached max_epochs = 53 with best_epoch = 29 and best_test_balanced_accuracy = 0.86135
C:\Users\naga_\AppData\Roaming\Python\Python39\site-packages\pytorch_tabnet\callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!
  warnings.warn(wrn_msg)
In [58]:
tn_predicted=tn_clf.predict(X_test)
In [59]:
error=0
error = mean_squared_error(y_test,tn_predicted)
print('Mean squred error is : ',error)
tn_accuracy = accuracy_score(y_test,tn_predicted)
print("Accuracy of the TabNet model is : {0:.2f}".format(tn_accuracy*100),'%')
Mean squred error is :  0.1320754716981132
Accuracy of the TabNet model is : 86.79 %

Accuracy measurements¶

In [60]:
print("Accuracy of the SVM model is : {0:.2f}".format(svm_accuracy*100),'%')
print("Accuracy of Random Forest: ",rf_max_accuracy,"%")
print("Accuracy of the TabNet model is : {0:.2f}".format(tn_accuracy*100),'%')
Accuracy of the SVM model is : 84.91 %
Accuracy of Random Forest:  90.57 %
Accuracy of the TabNet model is : 86.79 %

Dynamic Input¶

In [61]:
p_age = int(input("Enter Patient Age(in number) :"))
p_trestbps = int(input("Enter Patient trestbps value: "))
p_chol = int(input("Enter Patient Cholesterol: "))
p_thalach = int(input("Enter maximum heart rate: "))
p_oldpeak = int(input("Enter patient oldpeak value: "))

p_fbs = int(input("Enter fbs value\n0 is for normal value\n1 is for abnormal value\n: "))
p_cp = int(input("Enter the Chest pain type:\n1 is for Typical\n2 is for Atypical\n3 is for Non-Anginal pain\n4 is for Asymptomatic :"))
p_exang = int(input("Enter exang value (1 is for YES\n 0 is for NO\n): "))
p_sex = int(input("Enter gender of patient \n(1 is for Male\n 2 is for Female : )"))
p_restecg = int(input("Enter restecg value:\n 0 is for normal\n 1 is for havig ST-T\n 2 is for hypertrophy :" ))
Enter Patient Age(in number) :28
Enter Patient trestbps value: 130
Enter Patient Cholesterol: 132
Enter maximum heart rate: 185
Enter patient oldpeak value: 0
Enter fbs value
0 is for normal value
1 is for abnormal value
: 0
Enter the Chest pain type:
1 is for Typical
2 is for Atypical
3 is for Non-Anginal pain
4 is for Asymptomatic :2
Enter exang value (1 is for YES
 0 is for NO
): 0
Enter gender of patient 
(1 is for Male
 2 is for Female : )1
Enter restecg value:
 0 is for normal
 1 is for havig ST-T
 2 is for hypertrophy :2
In [62]:
p_numerical_cols=[[p_age,p_trestbps,p_chol,p_thalach,p_oldpeak]]
if p_cp==1:
    p_cp_1=True
    p_cp_2=False
    p_cp_3=False
    p_cp_4=False
elif p_cp==2:
    p_cp_1=False
    p_cp_2=True
    p_cp_3=False
    p_cp_4=False
elif p_cp==3:
    p_cp_1=False
    p_cp_2=False
    p_cp_3=True
    p_cp_4=False
elif p_cp==4:
    p_cp_1=False
    p_cp_2=False
    p_cp_3=True
    p_cp_4=False
else:
    p_cp_1=False
    p_cp_2=False
    p_cp_3=False
    p_cp_4=False
if p_restecg==0:
    p_restecg_0=True
    p_restecg_1=False
    p_restecg_2=False
elif p_restecg==1:
    p_restecg_0=False
    p_restecg_1=True
    p_restecg_2=False
elif p_restecg==2:
    p_restecg_0=False
    p_restecg_1=False
    p_restecg_2=True
p_cat_cols=[[p_fbs, p_cp_2, p_cp_3, p_exang, p_cp_1, p_sex, p_cp_4, p_restecg_0, p_restecg_1, p_restecg_2]]
In [63]:
print("Numerical values: ",p_numerical_cols)
Numerical values:  [[28, 130, 132, 185, 0]]
In [64]:
print("Categorical values: ",p_cat_cols)
Categorical values:  [[0, True, False, 0, False, 1, False, False, False, True]]
In [65]:
dummy_cat1=[[0,False,False,0,True,0,False,False,True,False]]
dummy_cat2=[[0,True,False,0,False,0,False,True,False,False]]
dummy_num1=[[30,170,237,170,0]]
dummy_num2=[[32,105,198,165,0]]
p_cate_cols=p_cat_cols+dummy_cat1+dummy_cat2
p_numeri_cols=p_numerical_cols+dummy_num1+dummy_num2
print(p_cate_cols)
print(p_numeri_cols)
[[0, True, False, 0, False, 1, False, False, False, True], [0, False, False, 0, True, 0, False, False, True, False], [0, True, False, 0, False, 0, False, True, False, False]]
[[28, 130, 132, 185, 0], [30, 170, 237, 170, 0], [32, 105, 198, 165, 0]]
In [66]:
p_numeri_cols=np.array(p_numeri_cols)
p_cate_cols=np.array(p_cate_cols)
p_cate_cols.reshape(3,10)
p_numeri_cols.reshape(3,5)
print(p_cate_cols)
print(p_numeri_cols)
[[0 1 0 0 0 1 0 0 0 1]
 [0 0 0 0 1 0 0 0 1 0]
 [0 1 0 0 0 0 0 1 0 0]]
[[ 28 130 132 185   0]
 [ 30 170 237 170   0]
 [ 32 105 198 165   0]]
In [67]:
def my_fun(p_numeri_cols,p_cate_cols,scaler):
    p_x_scaled = scaler.fit_transform(p_numeri_cols)
    p_x_cat = p_cate_cols
    p_x = np.hstack((p_x_cat,p_x_scaled))
    return p_x
p_data_x = my_fun(p_numeri_cols,p_cate_cols,scaler)
In [68]:
p_data_x
Out[68]:
array([[ 0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ,
        -1.22474487, -0.18677184, -1.31530679,  1.37281295,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  1.30740289,  1.10762677, -0.39223227,  0.        ],
       [ 0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         1.22474487, -1.12063105,  0.20768002, -0.98058068,  0.        ]])
In [69]:
print(p_data_x.shape)
(3, 15)
In [70]:
predicted_op_of_Random_forest=rf_clf.predict(p_data_x)
predicted_op_of_TabNet=tn_clf.predict(p_data_x)
predicted_op_of_Support_Vector_Machine=svm_clas.predict(p_data_x)
In [71]:
print("Predicted output of Random_forest:",predicted_op_of_Random_forest[0])
print("Predicted output of TabNet:",predicted_op_of_TabNet[0])
print("Predicted output of SVM:",predicted_op_of_Support_Vector_Machine[0])
Predicted output of Random_forest: 0
Predicted output of TabNet: 1
Predicted output of SVM: 0
In [77]:
count1 = 0
if(predicted_op_of_Support_Vector_Machine[0]!=None):
    print("The SVM detects that, ",end="")
    if predicted_op_of_Support_Vector_Machine[0]==0:
        print("The patient seems to be NORMAL\n")
    else:
        count1+=1
        print("The patient have the RISK to get heart disease,please conduct a doctor immediately\n\n")
if(predicted_op_of_Random_forest[0]!=None):
    print("The Random Forest detects that, ",end="")
    if predicted_op_of_Random_forest[0]==0:
        print("The patient seems to be NORMAL\n")
    else:
        count1+=1
        print("The patient have the RISK to get heart disease,please conduct a doctor immediately\n\n")
if(predicted_op_of_TabNet[0]!=None):
    print("The TabNet detects that, ",end="")
    if predicted_op_of_TabNet[0]==0:
        print("The patient seems to be NORMAL\n")
    else:
        count1+=1
        print("The patient have the RISK to get heart disease,please conduct a doctor immediately\n\n")
print("Final conclusion:")        
if count1>=2:
    print("The patirnt have RISK to get Heart Disease")
else:
     print("The patient seems to be NORMAL,Take care of your Health")
The SVM detects that, The patient seems to be NORMAL

The Random Forest detects that, The patient seems to be NORMAL

The TabNet detects that, The patient have the RISK to get heart disease,please conduct a doctor immediately


Final conclusion:
The patient seems to be NORMAL,Take care of your Health

Confusion matrix & Classification Report¶

In [73]:
from sklearn import metrics
print("Confusion matrix for SVM : ")
confusion_matrix = metrics.confusion_matrix(y_test, y_predict)

cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = [False, True])
cm_display.plot()
plt.show()
Confusion matrix for SVM : 
In [74]:
from sklearn import metrics
print("Confusion matrix for Random Forest : ")
confusion_matrix = metrics.confusion_matrix(y_test,Y_pred_rf)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = [False, True])
cm_display.plot()
plt.show()
Confusion matrix for Random Forest : 
In [75]:
from sklearn import metrics
print("Confusion matrix for TabNet : ")
confusion_matrix = metrics.confusion_matrix(y_test,tn_predicted)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = [False, True])
cm_display.plot()
plt.show()
Confusion matrix for TabNet : 
In [76]:
from sklearn.metrics import classification_report
print("SVM Classification Report :\n",classification_report(y_test,y_predic),"\n")
print("Random Forest Classification Report :\n",classification_report(y_test,Y_pred_rf),"\n")
print("TabNet Classification Report :\n",classification_report(y_test,tn_predicted),"\n")
SVM Classification Report :
               precision    recall  f1-score   support

           0       0.80      0.97      0.88        29
           1       0.94      0.71      0.81        24

    accuracy                           0.85        53
   macro avg       0.87      0.84      0.84        53
weighted avg       0.87      0.85      0.85        53
 

Random Forest Classification Report :
               precision    recall  f1-score   support

           0       0.81      0.90      0.85        29
           1       0.86      0.75      0.80        24

    accuracy                           0.83        53
   macro avg       0.83      0.82      0.83        53
weighted avg       0.83      0.83      0.83        53
 

TabNet Classification Report :
               precision    recall  f1-score   support

           0       0.84      0.93      0.89        29
           1       0.90      0.79      0.84        24

    accuracy                           0.87        53
   macro avg       0.87      0.86      0.86        53
weighted avg       0.87      0.87      0.87        53
 

In [ ]: